scatter axis + gather axis primitives #1813

Merged
merged 3 commits into from
Feb 1, 2025

Member

@awni awni commented Jan 31, 2025

Add a GatherAxis and ScatterAxis primitive to support take_along_axis and put_along_axis.

ScatterAxis supports two reduce modes: none and sum. The sum mode is useful for the gradient of GatherAxis. I did not add more reduce modes, to keep the complexity manageable. The other modes are still available through Scatter, and we can consider adding them here in the future.
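As a rough illustration of the semantics described above, here is a minimal CPU sketch for 2-D row-major data with the axis fixed to 1. The names `Mat`, `Reduce`, `gather_axis1`, and `scatter_axis1` are hypothetical helpers for this example only; they are not the MLX primitives or their signatures.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical reduce modes, mirroring the two modes named in the description.
enum class Reduce { None, Sum };

// Minimal row-major 2-D container (illustrative helper, not an MLX type).
template <typename T>
struct Mat {
  std::size_t rows;
  std::size_t cols;
  std::vector<T> data;

  Mat(std::size_t r, std::size_t c, std::vector<T> d)
      : rows(r), cols(c), data(std::move(d)) {
    assert(data.size() == rows * cols);
  }
  T& operator()(std::size_t r, std::size_t c) { return data[r * cols + c]; }
  const T& operator()(std::size_t r, std::size_t c) const {
    return data[r * cols + c];
  }
};

// Gather along axis 1: out(i, j) = src(i, idx(i, j)).
Mat<float> gather_axis1(const Mat<float>& src, const Mat<int>& idx) {
  assert(src.rows == idx.rows);
  Mat<float> out(
      idx.rows, idx.cols, std::vector<float>(idx.rows * idx.cols, 0.0f));
  for (std::size_t i = 0; i < idx.rows; ++i) {
    for (std::size_t j = 0; j < idx.cols; ++j) {
      const int k = idx(i, j);
      assert(k >= 0 && static_cast<std::size_t>(k) < src.cols);
      out(i, j) = src(i, static_cast<std::size_t>(k));
    }
  }
  return out;
}

// Scatter along axis 1 into a copy of src:
//   Reduce::None: out(i, idx(i, j))  = upd(i, j)
//   Reduce::Sum:  out(i, idx(i, j)) += upd(i, j)
Mat<float> scatter_axis1(
    const Mat<float>& src,
    const Mat<int>& idx,
    const Mat<float>& upd,
    Reduce mode) {
  assert(src.rows == idx.rows);
  assert(idx.rows == upd.rows && idx.cols == upd.cols);
  Mat<float> out = src; // copy src into out before applying updates
  for (std::size_t i = 0; i < idx.rows; ++i) {
    for (std::size_t j = 0; j < idx.cols; ++j) {
      const int k = idx(i, j);
      assert(k >= 0 && static_cast<std::size_t>(k) < src.cols);
      const std::size_t col = static_cast<std::size_t>(k);
      if (mode == Reduce::Sum) {
        out(i, col) += upd(i, j);
      } else {
        out(i, col) = upd(i, j);
      }
    }
  }
  return out;
}
```

In this sketch, the gradient of `gather_axis1` with respect to `src` is `scatter_axis1` applied to a zero array with `Reduce::Sum`, which is why repeated indices must accumulate rather than overwrite. With repeated indices in `Reduce::None`, this sequential loop keeps the last write; a parallel kernel may not give the same guarantee.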

The kernels are JIT-compiled by default, since they are pretty simple but come in a lot of combinations.

Incidentally closes #1807

TODO:

  • transforms
  • benchmarks

@awni awni force-pushed the gather_scatter_axis branch from c273d38 to 6fb1fff Compare January 31, 2025 00:52
Member Author

awni commented Jan 31, 2025

Other benchmarks:

| Benchmark | Pre (msec) | Post (msec) |
| --- | --- | --- |
| Small take | 0.635 | 0.319 |
| Large take | 439.15 | 22.55 |
| Small put | 0.569 | 0.317 |
| Large put | 16.33 | 14.68 |
| DSV3 Expert Score | 2.3 | 0.9 |

Member

@angeloskath angeloskath left a comment

Very nice, very clean. I left only one minor comment on the copy in the CPU side.

I am wondering whether it makes sense for some of these to simply output non-contiguous arrays. Most ops in MLX output contiguous arrays, i.e., we greedily take the hit as early as possible (with the exception of unary, binary, etc.). In this case we could always treat one of the two arrays (src or idx) as contiguous and adjust the output order accordingly.

Anyway, I guess it is a case rare enough to not matter so the above can be categorized as a rant.

auto& updates = inputs[2];

// Copy src into out (copy allocates memory for out)
copy(src, out, CopyType::General);
Member

Unless I am missing something, this needs to change to something that figures out the copy type. Same goes for normal Scatter. On the GPU side we have that already but I think it makes sense to go to common/copy.h. It would also enable donation of src which would be nice.

Member Author

Good point, nice low-hanging fruit!
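For illustration, the suggestion to "figure out the copy type" could look roughly like the sketch below. It assumes a three-way split (single element, row-contiguous, everything else); `CopyKind`, `ArrayInfo`, and `pick_copy_kind` are hypothetical stand-ins for this example and are not the actual helpers in common/copy.h.

```cpp
#include <cstddef>

// Hypothetical copy kinds, loosely modeled on the idea of a cheap path for
// scalars and contiguous data and a strided fallback for everything else.
enum class CopyKind { Scalar, Vector, General };

// Hypothetical summary of the source array's layout.
struct ArrayInfo {
  std::size_t data_size; // number of elements actually stored
  bool row_contiguous;   // true if the data is laid out in row-major order
};

// Pick the cheapest copy that is valid for the given source layout instead of
// always paying for the general strided copy.
CopyKind pick_copy_kind(const ArrayInfo& src) {
  if (src.data_size == 1) {
    return CopyKind::Scalar; // broadcast a single value
  }
  if (src.row_contiguous) {
    return CopyKind::Vector; // flat element-by-element copy
  }
  return CopyKind::General; // strided copy
}
```

Under these assumptions, the unconditional general copy in the reviewed line would only be taken for sources that are neither a single element nor row-contiguous.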

@@ -35,6 +35,8 @@ make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(scatter kernels/indexing.h)
make_jit_source(gather kernels/indexing.h)
make_jit_source(gather_axis)
make_jit_source(scatter_axis)
Member

❤️

kernel_name += upd.flags().row_contiguous ? "c" : "nc";
kernel_name += idx.flags().row_contiguous ? "c" : "nc";

auto lib = d.get_library(lib_name, [&]() {
Member

100% not against it, but are we moving towards having things inline unless they are used in multiple places, in which case we move them to kernels.h? Or is it only for things that are always jitted?

Member Author

Indeed, right now the pattern is inline if it's always jitted and in kernels.h otherwise (since jit_kernels.cpp only gets included when JIT is enabled).
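The inline pattern discussed in this thread (a kernel name that encodes contiguity, plus a library fetched through a builder callback) can be sketched as follows. This is a simplified model under stated assumptions: `Library`, `LibraryCache`, and `contiguity_suffix` are hypothetical names for this example, and the real `get_library` may differ in signature and behavior.

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a compiled kernel library.
struct Library {
  std::string source;
};

// Hypothetical cache: the builder runs only the first time a name is requested.
class LibraryCache {
 public:
  const Library& get_library(
      const std::string& name,
      const std::function<std::string()>& builder) {
    auto it = libs_.find(name);
    if (it == libs_.end()) {
      // First request for this name: generate the source and store it.
      it = libs_.emplace(name, Library{builder()}).first;
    }
    return it->second;
  }

  std::size_t size() const { return libs_.size(); }

 private:
  std::unordered_map<std::string, Library> libs_;
};

// Build the "c"/"nc" tags in the same order as the reviewed snippet:
// first the updates array, then the indices array.
std::string contiguity_suffix(bool upd_row_contiguous, bool idx_row_contiguous) {
  std::string suffix;
  suffix += upd_row_contiguous ? "c" : "nc";
  suffix += idx_row_contiguous ? "c" : "nc";
  return suffix;
}
```

Because each contiguity combination gets its own name, a cache of this shape would generate only the variants that are actually requested, which fits kernels that are simple but have many combinations.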

}

lhs_indices = astype(lhs_indices, uint32, s);
rhs_indices = astype(rhs_indices, uint32, s);
Member

🙏

@awni awni force-pushed the gather_scatter_axis branch from 93dbfc9 to eb9a9d5 Compare January 31, 2025 22:44
@awni awni force-pushed the gather_scatter_axis branch from eb9a9d5 to 199baf0 Compare February 1, 2025 00:02
@awni awni merged commit b7c9f1d into main Feb 1, 2025
5 checks passed
@awni awni deleted the gather_scatter_axis branch February 1, 2025 04:48
Successfully merging this pull request may close these issues.

[BUG] Put along axis crashes on shapes that it should work for
2 participants